#include <algorithm>
#include <memory>
#include "catalog/column.h"
#include "catalog/schema.h"
#include "common/exception.h"
#include "common/macros.h"
#include "execution/expressions/column_value_expression.h"
#include "execution/expressions/comparison_expression.h"
#include "execution/expressions/constant_value_expression.h"
#include "execution/expressions/logic_expression.h"
#include "execution/plans/abstract_plan.h"
#include "execution/plans/filter_plan.h"
#include "execution/plans/hash_join_plan.h"
#include "execution/plans/nested_loop_join_plan.h"
#include "execution/plans/projection_plan.h"
#include "optimizer/optimizer.h"
#include "type/type_id.h"

namespace bustub {

auto Optimizer::OptimizeNLJAsHashJoin(const AbstractPlanNodeRef &plan) -> AbstractPlanNodeRef {
  std::vector<AbstractPlanNodeRef> children;
  for (const auto &child : plan->GetChildren()) {
    children.emplace_back(OptimizeNLJAsHashJoin(child));
  }
  auto optimized_plan = plan->CloneWithChildren(std::move(children));

  if (optimized_plan->GetType() == PlanType::NestedLoopJoin) {
    const auto &nlj_plan = dynamic_cast<const NestedLoopJoinPlanNode &>(*optimized_plan);
    // Has exactly two children
    BUSTUB_ENSURE(nlj_plan.children_.size() == 2, "NLJ should have exactly 2 children.");

    // Check if expr is equal condition where one is for the left table, and one is for the right table.
    if (const auto *expr = dynamic_cast<const ComparisonExpression *>(nlj_plan.Predicate().get()); expr != nullptr) {
      // <column expr> = <column expr>
      // left.col0=right.col0
      // right.col0=left.col0
      if (expr->comp_type_ == ComparisonType::Equal) {
        if (const auto *left_expr = dynamic_cast<const ColumnValueExpression *>(expr->children_[0].get());
            left_expr != nullptr) {
          if (const auto *right_expr = dynamic_cast<const ColumnValueExpression *>(expr->children_[1].get());
              right_expr != nullptr) {
            // Ensure both exprs have tuple_id == 0
            auto left_expr_tuple_0 =
                std::make_shared<ColumnValueExpression>(0, left_expr->GetColIdx(), left_expr->GetReturnType());
            auto right_expr_tuple_0 =
                std::make_shared<ColumnValueExpression>(0, right_expr->GetColIdx(), right_expr->GetReturnType());
            // Now it's in form of <column_expr> = <column_expr>. Let's check if one of them is from the left
            // table, and the other is from the right table.
            // 左表谓词在左，右表谓词在右
            std::vector<AbstractExpressionRef> left;
            std::vector<AbstractExpressionRef> right;
            if (left_expr->GetTupleIdx() == 0 && right_expr->GetTupleIdx() == 1) {
              left.push_back(left_expr_tuple_0);
              right.push_back(right_expr_tuple_0);
              return std::make_shared<HashJoinPlanNode>(nlj_plan.output_schema_, nlj_plan.GetLeftPlan(),
                                                        nlj_plan.GetRightPlan(), std::move(left), std::move(right),
                                                        nlj_plan.GetJoinType());
            }
            // 左表谓词在右，右表谓词在左
            if (left_expr->GetTupleIdx() == 1 && right_expr->GetTupleIdx() == 0) {
              left.push_back(right_expr_tuple_0);
              right.push_back(left_expr_tuple_0);
              return std::make_shared<HashJoinPlanNode>(nlj_plan.output_schema_, nlj_plan.GetLeftPlan(),
                                                        nlj_plan.GetRightPlan(), std::move(left), std::move(right),
                                                        nlj_plan.GetJoinType());
            }
          }
        }
      }
    }
    if (const auto *expr = dynamic_cast<const LogicExpression *>(nlj_plan.Predicate().get()); expr != nullptr) {
      // <column expr> = <column expr> AND <column expr> = <column expr>

      if (expr->logic_type_ == LogicType::And) {
        if (const auto *comp_expr_0 = dynamic_cast<const ComparisonExpression *>(expr->children_[0].get());
            comp_expr_0 != nullptr) {
          if (const auto *comp_expr_1 = dynamic_cast<const ComparisonExpression *>(expr->children_[1].get());
              comp_expr_1 != nullptr) {
            if (comp_expr_0->comp_type_ == ComparisonType::Equal && comp_expr_1->comp_type_ == ComparisonType::Equal) {
              const auto *left_expr0 = dynamic_cast<const ColumnValueExpression *>(comp_expr_0->children_[0].get());
              const auto *right_expr0 = dynamic_cast<const ColumnValueExpression *>(comp_expr_0->children_[1].get());
              const auto *left_expr1 = dynamic_cast<const ColumnValueExpression *>(comp_expr_1->children_[0].get());
              const auto *right_expr1 = dynamic_cast<const ColumnValueExpression *>(comp_expr_1->children_[1].get());
              if (left_expr0 != nullptr && right_expr0 != nullptr && left_expr1 != nullptr && right_expr1 != nullptr) {
                auto left_expr0_tuple_0 =
                    std::make_shared<ColumnValueExpression>(0, left_expr0->GetColIdx(), left_expr0->GetReturnType());
                auto right_expr0_tuple_0 =
                    std::make_shared<ColumnValueExpression>(0, right_expr0->GetColIdx(), right_expr0->GetReturnType());
                auto left_expr1_tuple_0 =
                    std::make_shared<ColumnValueExpression>(0, left_expr1->GetColIdx(), left_expr1->GetReturnType());
                auto right_expr1_tuple_0 =
                    std::make_shared<ColumnValueExpression>(0, right_expr1->GetColIdx(), right_expr1->GetReturnType());
                // 左表谓词0在左(右表谓词0在右)，左表谓词1在左（右表谓词1在右）
                // 左表谓词0在左，左表谓词1在右
                // 左表谓词0在右，左表谓词1在左
                // 左表谓词0在右，左表谓词1在右
                if (left_expr0->GetTupleIdx() == 0 && right_expr0->GetTupleIdx() == 1 &&
                    left_expr1->GetTupleIdx() == 0 && right_expr1->GetTupleIdx() == 1) {
                  std::vector<AbstractExpressionRef> left;
                  std::vector<AbstractExpressionRef> right;
                  left.push_back(left_expr0_tuple_0);
                  left.push_back(left_expr1_tuple_0);
                  right.push_back(right_expr0_tuple_0);
                  right.push_back(right_expr1_tuple_0);
                  return std::make_shared<HashJoinPlanNode>(nlj_plan.output_schema_, nlj_plan.GetLeftPlan(),
                                                            nlj_plan.GetRightPlan(), std::move(left), std::move(right),
                                                            nlj_plan.GetJoinType());
                }
                if (left_expr0->GetTupleIdx() == 0 && right_expr0->GetTupleIdx() == 1 &&
                    left_expr1->GetTupleIdx() == 1 && right_expr1->GetTupleIdx() == 0) {
                  std::vector<AbstractExpressionRef> left;
                  std::vector<AbstractExpressionRef> right;
                  left.push_back(left_expr0_tuple_0);
                  left.push_back(right_expr1_tuple_0);
                  right.push_back(right_expr0_tuple_0);
                  right.push_back(left_expr1_tuple_0);
                  return std::make_shared<HashJoinPlanNode>(nlj_plan.output_schema_, nlj_plan.GetLeftPlan(),
                                                            nlj_plan.GetRightPlan(), std::move(left), std::move(right),
                                                            nlj_plan.GetJoinType());
                }

                if (left_expr0->GetTupleIdx() == 1 && right_expr0->GetTupleIdx() == 0 &&
                    left_expr1->GetTupleIdx() == 0 && right_expr1->GetTupleIdx() == 1) {
                  std::vector<AbstractExpressionRef> left;
                  std::vector<AbstractExpressionRef> right;
                  left.push_back(right_expr0_tuple_0);
                  left.push_back(left_expr1_tuple_0);
                  right.push_back(left_expr0_tuple_0);
                  right.push_back(right_expr1_tuple_0);
                  return std::make_shared<HashJoinPlanNode>(nlj_plan.output_schema_, nlj_plan.GetLeftPlan(),
                                                            nlj_plan.GetRightPlan(), std::move(left), std::move(right),
                                                            nlj_plan.GetJoinType());
                }

                if (left_expr0->GetTupleIdx() == 1 && right_expr0->GetTupleIdx() == 1 &&
                    left_expr1->GetTupleIdx() == 1 && right_expr1->GetTupleIdx() == 0) {
                  std::vector<AbstractExpressionRef> left;
                  std::vector<AbstractExpressionRef> right;
                  left.push_back(right_expr0_tuple_0);
                  left.push_back(right_expr1_tuple_0);
                  right.push_back(left_expr0_tuple_0);
                  right.push_back(left_expr1_tuple_0);
                  return std::make_shared<HashJoinPlanNode>(nlj_plan.output_schema_, nlj_plan.GetLeftPlan(),
                                                            nlj_plan.GetRightPlan(), std::move(left), std::move(right),
                                                            nlj_plan.GetJoinType());
                }
              }
            }
          }
        }
      }
    }
  }
  return plan;
}

}  // namespace bustub
